#include<bits/stdc++.h>
#include<armadillo>
#include "Cell.h"

/*
 * compile and run: make 2D_CN && ./run.sh 2D_CN
 *
 * Crank-Nicolson method (2D)
*/

using namespace std;
using namespace arma;

// Shorthand for the floating-point type used throughout this file.
// NOTE(review): kept as a macro (not a typedef) — presumably to match Cell.h;
// confirm Cell.h does not define `db` itself before converting it.
#define db double
// NOTE(review): PI is not referenced in this file as shown (the code uses the
// Cell member `pi` instead); candidate for removal.
#define PI 3.14159265

/*
 * Build the implicit-side Crank-Nicolson matrix for the Nx x Ny interior grid.
 *
 * Unknowns are numbered row-major: node = (i-1)*Nx + (j-1), with i the grid
 * row (1..Ny) and j the grid column (1..Nx). Every row of A has
 * 1 + 2*Sx + 2*Sy on the diagonal and -Sx / -Sy for each neighbour in x / y
 * that is itself an interior node. Neighbours on the ghost ring are NOT
 * written here; main() adds their contribution to the right-hand side.
 *
 * A is resized to (Nx*Ny) x (Nx*Ny) and zeroed first, so the result does not
 * depend on how the caller constructed it (some Armadillo versions do not
 * zero-initialise mat(n_rows, n_cols)).
 *
 * @param A   output matrix (overwritten)
 * @param Nx  interior cells per grid row (>= 1)
 * @param Ny  interior cells per grid column (>= 1)
 * @param Sx  x coupling coefficient dt*Dx/(2*deltaX^2)
 * @param Sy  y coupling coefficient dt*Dy/(2*deltaY^2)
 */
void load_Matrix_A(mat &A, int Nx, int Ny,db Sx, db Sy){
  const int n = Nx*Ny;
  const db diag = 2.0*Sx + 2.0*Sy + 1.0;
  A.zeros(n, n);

  for(int i=1; i<=Ny; i++){      // grid rows (y)
    for(int j=1; j<=Nx; j++){    // grid columns (x)
      const int node = (i-1)*Nx + (j-1);
      A(node, node) = diag;

      // Each side is checked independently: this reproduces the corner /
      // edge / centre cases for Nx,Ny >= 2 and stays in range (and correct)
      // when Nx == 1 or Ny == 1.
      if(j > 1)  A(node, node - 1)  = -Sx;   // left neighbour
      if(j < Nx) A(node, node + 1)  = -Sx;   // right neighbour
      if(i > 1)  A(node, node - Nx) = -Sy;   // neighbour below
      if(i < Ny) A(node, node + Nx) = -Sy;   // neighbour above
    }
  }
}

/*
 * Scatter the interior solution X back onto the padded grid.
 *
 * X holds the Nx*Ny interior unknowns row-major. The padded grid is Nx+2 wide
 * (one ghost column on each side) and has one ghost row below, so interior
 * unknown k at (k / Nx, k % Nx) maps to padded index
 * (k / Nx + 1) * (Nx + 2) + (k % Nx + 1). Both cells[].V and prevV are updated;
 * ghost entries are left untouched.
 */
void copy_voltage(vector<Cell> &cells, vec &X, vec &prevV,int Nx){
  const int width = Nx + 2;
  const int count = X.size();
  for(int k = 0; k < count; ++k){
    const int row = k / Nx + 1;     // skip the bottom ghost row
    const int col = k % Nx + 1;     // skip the left ghost column
    const int padded = row*width + col;
    const double v = X(k);
    cells[padded].V = v;
    prevV(padded) = v;
  }
}

// Print the finite-difference voltages for one time instant: a single line
// with t followed by every entry of X, each preceded by two spaces.
void print_solutions(vec &X,db t){
  cout << t;
  for(const double value : X)
    cout << "  " << value;
  cout << endl;
}

/*
 * 2D Crank-Nicolson simulation of an Nx x Ny sheet of cells.
 *
 * Usage: <prog> Nx Ny   (both positive integers)
 *
 * The grid is padded with one ring of ghost nodes ((Nx+2) x (Ny+2) in total).
 * Ghost entries of prevV keep the value prevV was filled with, acting as fixed
 * (Dirichlet) boundary values. Each step assembles the right-hand side B for
 * the Nx*Ny interior nodes and solves A X = B with Armadillo.
 *
 * Prints two lines: CPU seconds spent in the time loop, then CPU seconds spent
 * in initialisation (both measured with clock()).
 */
int main(int argc, char* argv[]){
  clock_t startinit = clock();

  // Reading argv[1]/argv[2] without this check is undefined behaviour.
  if(argc < 3){
    fprintf(stderr, "usage: %s Nx Ny\n", argc > 0 ? argv[0] : "2D_CN");
    return 1;
  }

  db nrepeat;     // number of beats
  db tbegin;      // start time of the (next) stimulus, ms
  db BCL;         // basic cycle length: time between stimuli, ms
  db CI;          // extra delay added to tbegin once all beats are done, ms
  db dt;          // time step, ms
  db dtstim;      // stimulus duration, ms
  db CurrStim;    // stimulus current
  db cell_type;   // NOTE(review): assigned but not used in this file
  int nstp_prn;   // print frequency in steps (used by the commented-out print)
  db tend;        // end time of the current stimulus window, ms
  db nstp;        // number of time steps
  int cell_to_stim; // NOTE(review): assigned but not used in this file
  db deltaX,deltaY; // grid spacing, cm
  int Nx,Ny,nodes,nodesA;
  db Dx,Dy;       // diffusion coefficients
  db Sx,Sy;       // CN coefficients dt*D/(2*delta^2)
  int upper,lower,prev,next, pos,i,j;
  db Iion;
  db Jion;        // current density
  db cont_repeat = 0;
  db t = 0.0;
  int flag_stm = 1; // 0 while inside a stimulus window, 1 otherwise
  db Istim = 0.0;
//-------------------------------------
  nrepeat = 1;   //60-> 1min, 600-> 10min
  tbegin = 50;
  BCL =  600;
  CI = 0;
  dtstim = 2;
  CurrStim = -8000;
  cell_type = 1;
  nstp_prn = 20;
  tend = tbegin+dtstim;
  Nx = atoi(argv[1]);
  Ny = atoi(argv[2]);
  cell_to_stim = 47;   // 70 in plot
  dt = 0.02; //ms
  deltaX = deltaY = 0.025;/// cm
//-------------------------------------
  if(Nx < 1 || Ny < 1){
    fprintf(stderr, "Nx and Ny must be positive integers\n");
    return 1;
  }

  // A whole grid row (row_to_stim, columns 1..Nx of the padded grid) is stimulated.
  db row_to_stim = 4;
  db begin_cell = row_to_stim*(Nx+2) + 1;

  nstp = (tbegin+BCL*nrepeat+CI)/dt;
  nodes = (Nx+2)*(Ny+2);                 // nodes including boundary conditions
  nodesA = Nx*Ny;                        // nodes calculated in matrix A.

  vector<Cell> cells(nodes);
  db areaT = cells[0].pi*pow(cells[0].a,2);  // Capacitive membrane area
  db aCm = cells[0].Cap / areaT;             // Capacitance per unit area pF/cm^2

  Dx = Dy = cells[0].a / (2.0*cells[0].Ri*aCm*1e-9); //D = 0.00217147 cm^2/ms

  Sx = (dt*Dx)/(2.0*pow(deltaX,2));
  Sy = (dt*Dy)/(2.0*pow(deltaY,2));

//-------------------------------------
  mat A = mat(nodesA,nodesA);         // A
  A.zeros();                          // only stencil entries get written
  vec B = vec(nodesA);                // B (every entry is rewritten each step)
  vec X = vec(nodesA);                // X from AX=B;
  vec prevV = vec(nodes);             // Voltages of T time
//-------------------------------------

  prevV.fill(-81.2);
  load_Matrix_A(A, Nx, Ny, Sx, Sy);

  //var for printing only the last ncharts beats
  int ncharts = 1;
  int time_to_print = nstp- ((ncharts*BCL+tbegin)/dt);
  // NOTE(review): this override makes the loop below run a single step
  // (k < nstp+2 == 1), presumably to benchmark one solve. Remove it to run
  // the full simulation.
  nstp = -1;
  clock_t endinit = clock();
  clock_t startfor = clock();
  for(int k=0; k<nstp+2; k++,t+=dt){ //each time
    pos = 0;

    // Stimulus bookkeeping: inside [tbegin, tend] the stimulus is on; on the
    // first step after a window closes, schedule the next window.
    if(t>=tbegin && t<=tend){
      flag_stm = 0;
    }else if(flag_stm==0){
      if(cont_repeat < nrepeat){
        tbegin=tbegin+BCL;   // time of the next stimulus
      }else if(cont_repeat == nrepeat) tbegin=tbegin+CI;

      cont_repeat++;
      tend=tbegin+dtstim;
      flag_stm = 1;
    }

    for(int node=Nx+3; node<(nodes-(Nx+3)); node++){
      j = node % (Nx+2);        // column in the padded grid -> x
      i = node / (Nx+2);        // row in the padded grid    -> y
      if(j<1 || j>Nx) continue; // ghost columns are not unknowns

      upper = node + (Nx+2);
      lower = node - (Nx+2);
      prev = node - 1;
      next = node + 1;

      // Stimulate the whole row while the stimulus window is active.
      bool in_stim_row = node >= begin_cell && node <= begin_cell + Nx - 1;
      Istim = (!flag_stm && in_stim_row) ? CurrStim : 0.0;

      Iion = cells[node].getItot(dt);

      // Boundary condition: a ghost neighbour's voltage is treated as constant,
      // so its implicit-side term moves to the RHS as S*prevV(ghost). Sides
      // are added independently, so corners get both terms (fixing the old
      // top-right test that compared j with Ny) and Nx==1 / Ny==1 grids work.
      db BC = 0;
      if(j==1)  BC += Sx*prevV(prev);
      if(j==Nx) BC += Sx*prevV(next);
      if(i==1)  BC += Sy*prevV(lower);
      if(i==Ny) BC += Sy*prevV(upper);

      // Explicit half of the Crank-Nicolson step.
      db rhs = Sx*prevV(prev) + (1.0-2.0*Sx-2.0*Sy)*prevV(node) + Sx*prevV(next) + Sy*prevV(lower) + Sy*prevV(upper);
      Jion = (Iion+Istim)/areaT;
      B(pos++) = rhs + BC - ((Jion)*dt/aCm);
    }
    // armadillo solver for AX=B
    //clock_t start = clock();
    X = solve(A,B);
    //clock_t end = clock();
    //db elapsed_seconds = end - start;
    //printf("%lf\n",(elapsed_seconds / CLOCKS_PER_SEC));

    //copy_voltage(cells,X,prevV,Nx);
    //if(k%nstp_prn==0 && k>time_to_print)   //use this for plot last beat
    //  print_solutions(X,t);
  }
  clock_t endfor = clock();
  db timefor = endfor - startfor;
  printf("%lf\n",(timefor / CLOCKS_PER_SEC));
  db timeinit = endinit - startinit;
  printf("%lf\n",(timeinit / CLOCKS_PER_SEC));

  return 0;
}
